-
Notifications
You must be signed in to change notification settings - Fork 2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Erase dtype and device #166
Conversation
…ed dtype and device parameters from Calculator
Done! I dropped Since I was not involved in the tuning code development, I would kindly ask you to pay special attention to the changes in that part to ensure I didn’t break anything. |
docs/src/references/changelog.rst
Outdated
@@ -32,6 +32,11 @@ Added | |||
* Require consistent ``dtype`` between ``positions`` and ``neighbor_distances`` in | |||
``Calculator`` classes and tuning functions. | |||
|
|||
Changed |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can probably change it to
Changed | |
Removed |
and also remove our statements in changed about using dtypes everywhere...
pot_1 = pot_1.to(dtype=dtype) | ||
pot_2 = pot_2.to(dtype=dtype) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe add a comment why you need this here.
cell = torch.eye( | ||
3, | ||
device=self.potential.smearing.device, | ||
dtype=self.potential.smearing.dtype, | ||
) | ||
ns_mesh = torch.ones(3, dtype=int, device=cell.device) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When I apply to .to
to the class, is this correctly passed to self.kspace_filter
and self.mesh_interpolator
?
src/torchpme/lib/splines.py
Outdated
|
||
# Calculate intervals | ||
intervals = x[1:] - x[:-1] | ||
dy = (y[1:] - y[:-1]) / intervals | ||
|
||
# Create zero boundary conditions (natural spline) | ||
d2y = torch.zeros_like(x, dtype=torch.float64) | ||
torch.zeros_like(x) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think you can remove it?
if potential.smearing is None: | ||
raise ValueError( | ||
"Must specify smearing to use a potential with P3MCalculator" | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think there is no test for it. Should be go into workflow
tests.
Also I think more and more that we need an Ewald
Base class and another "Base" class for the direct calculator.
I fact the EwaldCalculator can serve as a base class and PME and P3M only override the k-space method. But this is something for later...
A couple of PRs ago, we decided to include
dtype
anddevice
as explicit and obligatory parameters for bothcalculators
andpotentials
.Unfortunately, after thorough consideration of how typical pipelines are built, I concluded that we should abandon this design choice.
The main reason is that, in most cases, when working with an NN
model,
the preferred strategy is to first initialize themodel
and then move it to the desired device usingmodel.to(device)
.Since
torch-pme
is designed to be an internal part of themodel
, this creates a conflict. We initializedtype
anddevice
once, but when we later move themodel
to a differentdevice
, it undermines our prior device-checking logic.Luckily, since our entire pipeline is either a
torch.nn.Module
or its subclass, we can integrate it smoothly with models that change theirdevice
anddtype
. The key idea is to thoroughly rewrite the pipeline so that all newly created tensors during calculations are registered as buffers usingself.register_buffer
.This PR aims to achieve exactly that.
📚 Documentation preview 📚: https://torch-pme--166.org.readthedocs.build/en/166/